/* 
 * Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * - Redistributions of source code must retain the above copyright notice,
 *  this list of conditions and the following disclaimer.
 * 
 * - Redistributions in binary form must reproduce the above copyright notice,
 *  this list of conditions and the following disclaimer in the documentation
 *  and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "cudamat.cuh"


// Convolution forward/backward entry points implemented in a separate
// translation unit (C linkage). NOTE(review): parameter conventions follow
// cuda-convnet (images/filters/hidActs laid out filter-major) -- confirm
// against the definitions.
extern "C" void _filterActsCu(cudamat* images, cudamat* filters, cudamat* targets, int numModulesX, int paddingStart, int moduleStride,
         int numImgColors, int numGroups, float scaleTargets, float scaleOutput, bool conv);
extern "C" void _imgActsCu(cudamat* hidActs, cudamat* filters, cudamat* targets, int imgSize, int paddingStart, int moduleStride, 
       int numImgColors, int numGroups, float scaleTargets, float scaleOutput, bool conv);
extern "C" void _weightActsCu(cudamat* images, cudamat* hidActs, cudamat* targets, int numModulesX, int filterSize, int paddingStart, int moduleStride,
         int numImgColors, int numGroups, int partialSum, float scaleTargets, float scaleOutput);


// Response/contrast normalization wrappers (forward and gradient passes),
// also defined elsewhere in this project.
void convResponseNormCu(cudamat* images, cudamat* denoms, cudamat* target, int numFilters, int sizeX, float addScale, float powScale);

void convResponseNormUndoCu(cudamat* outGrads, cudamat* denoms, cudamat* inputs, cudamat* acts, cudamat* target, int numFilters,
              int sizeX, float addScale, float powScale, float scaleTargets, float scaleOutput);

void convContrastNormCu(cudamat* images, cudamat* meanDiffs, cudamat* denoms, cudamat* target, int numFilters, int sizeX, float addScale, float powScale);
void convContrastNormUndoCu(cudamat* outGrads, cudamat* denoms, cudamat* meanDiffs, cudamat* acts, cudamat* target, int numFilters,
             int sizeX, float addScale, float powScale, float scaleTargets, float scaleOutput);

/*
 * Aggregator for average pooling: accumulates a running sum via operator()
 * and divides by the (fixed) region size in output().
 */
class AvgPooler {
private:
  float _regionSize;  // divisor applied at output time (pooling-region element count)
public:
  // regionSize: number of elements contributing to each pooled output.
  AvgPooler(float regionSize) : _regionSize(regionSize) {
  }
  // Fold one more value into the running sum.
  __device__ inline float operator()(const float acc, const float v) const {
    return acc + v;
  }
  // Identity element of the sum.
  __device__ inline float getBaseValue() const {
    return 0.0f;
  }
  // Convert the accumulated sum into the average.
  __device__ inline float output(const float acc) const {
    return acc / _regionSize;
  }
};

/*
 * Aggregator for max pooling: keeps the running maximum of the region.
 */
class MaxPooler {
public:
  // Return the larger of the running maximum and the new value.
  // (Comparison order `acc > v` matches the original semantics exactly.)
  __device__ inline float operator()(const float acc, const float v) const {
    if (acc > v) {
      return acc;
    }
    return v;
  }
  // Very negative sentinel so any real pixel replaces it.
  __device__ inline float getBaseValue() const {
    return -2e38f;
  }
  // Max pooling emits the accumulator unchanged.
  __device__ inline float output(const float acc) const {
    return acc;
  }
};

/*
 * Aggregator for probabilistic max pooling: operator() compares two
 * candidates weighted by their random values and returns a 0/1 selector
 * (0 = keep the current winner (a, r1), 1 = take the challenger (b, r2)).
 * The caller (kLocalProbPool / kLocalProbPool2) uses that selector to
 * update its running winner and the random value that backed it.
 */
class ProbMaxPooler {
public:
  // 0 if a*r1 wins the weighted comparison, 1 otherwise.
  __device__ inline float operator()(const float a, const float b, const float r1, const float r2) const {
    return a * r1 > b * r2 ? 0 : 1;
  }
  // NOTE(review): this sentinel is used for BOTH the value and its random
  // weight in the kernels, so the first comparison computes
  // (-2e38)*(-2e38), which overflows float to +inf and makes the sentinel
  // win -- verify the intended base value (e.g. whether rnd should start
  // at a different identity).
  __device__ inline float getBaseValue() const {
    return -2e38; 
  }
  // Emits the accumulated winner unchanged.
  __device__ inline float output(const float a) const {
    return a;
  }
};

/*
 * Block size B_YxB_X
 * blockIdx.x determines output.x, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines output.y, filter idx in batches of B_Y*filtersPerThread
 * 
 * So each block does one output for some number of images/filters.
 * 
 * threadIdx.x determines img idx
 * threadIdx.y determines filter idx
 * 
 * imgs:    (numFilters, imgPixels, numImages)
 * rnd :    (numFilters, imgPixels, numImages)
 * target:   (numFilters, numOutputs, numImages)
 * 
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
 */

template<class Agg, int B_Y, int B_X, int imgsPerThread, int filtersPerThread, bool checkCaseBounds>
__global__ void kLocalProbPool(float* imgs, float* rnd, float* target, const int imgSize, const int numFilters,
              const int numImages, const int subsX, const int startX, const int strideX,
              const int outputsX, Agg agg) {
  // Decompose blockIdx: x packs (output.x, image batch), y packs (output.y, filter batch).
  const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread);
  const int numFilterBlocks = DIVUP(numFilters, B_Y*filtersPerThread);
  const int outputIdxX = blockIdx.x / numImgBlocks;
  const int outputIdxY = blockIdx.y / numFilterBlocks;
  const int blockImgIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
  const int blockFilterIdx = (blockIdx.y % numFilterBlocks) * B_Y * filtersPerThread;
  const int myFilterIdx = (blockFilterIdx + threadIdx.y*filtersPerThread);
  // Threads whose filter slice lies past the last filter have no work.
  if (myFilterIdx >= numFilters) {
    return;
  }
  
  const int outputIdx = outputIdxY * outputsX + outputIdxX;
  const int numOutputs = outputsX * outputsX;
  const int imgPixels = imgSize * imgSize;
  
  // Top-left corner of this output's pooling window in image coordinates
  // (may be negative when startX < 0).
  const int startImgPxX = startX + outputIdxX * strideX;
  const int startImgPxY = startX + outputIdxY * strideX;
  const int imgIdx = blockImgIdx + threadIdx.x;
  
  // Advance base pointers to this thread's first (filter, image) element.
  imgs += myFilterIdx * imgPixels * numImages + imgIdx;
  rnd += myFilterIdx * imgPixels * numImages + imgIdx;
  target += (myFilterIdx * numOutputs + outputIdx) * numImages + imgIdx;
  
  // Running winner per (filter, image), plus the random value that backed it;
  // both seeded with the aggregator's base value.
  float prod[filtersPerThread][imgsPerThread];
  float rnd_used[filtersPerThread][imgsPerThread];
  #pragma unroll
  for (int f = 0; f < filtersPerThread; f++) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      prod[f][i] = agg.getBaseValue(); 
      rnd_used[f][i] = agg.getBaseValue(); 
    }
  }
  
  // Clamp the pooling window to the image boundary.
  const int loopStartY = MAX(0, startImgPxY);
  const int loopStartX = MAX(0, startImgPxX);
  const int loopEndY = MIN(imgSize, startImgPxY + subsX);
  const int loopEndX = MIN(imgSize, startImgPxX + subsX);
  for (int y = loopStartY; y < loopEndY; y++) {
    for (int x = loopStartX; x < loopEndX; x++) {
      const int imgPx = y * imgSize + x;
      #pragma unroll
      for (int i = 0; i < imgsPerThread; i++) {
        if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
          #pragma unroll
          for (int f = 0; f < filtersPerThread; f++) {
            const int loc = (f * imgPixels + imgPx) * numImages + i * B_X;
            // agg returns a float 0/1 selector (see ProbMaxPooler); the
            // int conversion truncates it. res == 0 keeps the current
            // winner, otherwise the candidate pixel replaces it.
            const int res = agg(prod[f][i], imgs[loc], rnd_used[f][i], rnd[loc]);
            prod[f][i] = res == 0 ? prod[f][i] : imgs[loc];
            rnd_used[f][i] = res == 0 ? rnd_used[f][i] : rnd[loc];
          }
        }
      }
    }
  }
  
  // Write one output per (filter, image) handled by this thread.
  #pragma unroll
  for (int i = 0; i < imgsPerThread; i++) {
    if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
      #pragma unroll
      for (int f = 0; f < filtersPerThread; f++) {
        target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]); 
      }
    }
  }
}


/*
 * Block size B_YxB_X
 * blockIdx.x determines output.x, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines output.y, filter idx in batches of B_Y*filtersPerThread
 * 
 * So each block does one output for some number of images/filters.
 * 
 * threadIdx.x determines img idx
 * threadIdx.y determines filter idx
 * 
 * imgs:    (numFilters, imgPixels, numImages)
 * target:   (numFilters, numOutputs, numImages)
 * 
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
 */

template<class Agg, int B_Y, int B_X, int imgsPerThread, int filtersPerThread, bool checkCaseBounds>
__global__ void kLocalPool(float* imgs, float* target, const int imgSize, const int numFilters,
              const int numImages, const int subsX, const int startX, const int strideX,
              const int outputsX, Agg agg) {
  // Decompose blockIdx: x packs (output.x, image batch), y packs (output.y, filter batch).
  const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread);
  const int numFilterBlocks = DIVUP(numFilters, B_Y*filtersPerThread);
  const int outputIdxX = blockIdx.x / numImgBlocks;
  const int outputIdxY = blockIdx.y / numFilterBlocks;
  const int blockImgIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
  const int blockFilterIdx = (blockIdx.y % numFilterBlocks) * B_Y * filtersPerThread;
  const int myFilterIdx = (blockFilterIdx + threadIdx.y*filtersPerThread);
  // Threads whose filter slice lies past the last filter have no work.
  if (myFilterIdx >= numFilters) {
    return;
  }
  
  const int outputIdx = outputIdxY * outputsX + outputIdxX;
  const int numOutputs = outputsX * outputsX;
  const int imgPixels = imgSize * imgSize;
  
  // Top-left corner of this output's pooling window in image coordinates
  // (may be negative when startX < 0).
  const int startImgPxX = startX + outputIdxX * strideX;
  const int startImgPxY = startX + outputIdxY * strideX;
  const int imgIdx = blockImgIdx + threadIdx.x;
  
  // Advance base pointers to this thread's first (filter, image) element.
  imgs += myFilterIdx * imgPixels * numImages + imgIdx;
  target += (myFilterIdx * numOutputs + outputIdx) * numImages + imgIdx;
  
  // Per-thread accumulators seeded with the aggregator's identity value.
  float prod[filtersPerThread][imgsPerThread];
  #pragma unroll
  for (int f = 0; f < filtersPerThread; f++) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      prod[f][i] = agg.getBaseValue(); 
    }
  }
  
  // Clamp the pooling window to the image boundary.
  const int loopStartY = MAX(0, startImgPxY);
  const int loopStartX = MAX(0, startImgPxX);
  const int loopEndY = MIN(imgSize, startImgPxY + subsX);
  const int loopEndX = MIN(imgSize, startImgPxX + subsX);
  for (int y = loopStartY; y < loopEndY; y++) {
    for (int x = loopStartX; x < loopEndX; x++) {
      const int imgPx = y * imgSize + x;
      #pragma unroll
      for (int i = 0; i < imgsPerThread; i++) {
        if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
          #pragma unroll
          for (int f = 0; f < filtersPerThread; f++) {
            // Coalesced read: adjacent threads (threadIdx.x) hit adjacent images.
            prod[f][i] = agg(prod[f][i], imgs[(f * imgPixels + imgPx) * numImages + i * B_X]);
          }
        }
      }
    }
  }
  
  // Write one output per (filter, image) handled by this thread.
  #pragma unroll
  for (int i = 0; i < imgsPerThread; i++) {
    if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
      #pragma unroll
      for (int f = 0; f < filtersPerThread; f++) {
        target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]); 
      }
    }
  }
}


/*
 * Block size 16xB_X
 * blockIdx.x determines 4x4 pixel.x region, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines 4x4 pixel.y region, filter idx in batches of filtersPerThread
 * 
 * So each block does a 4x4 region for some number of images/filters.
 * 
 * threadIdx.x determines img idx
 * threadIdx.y determines pixel idx
 * 
 * imgs:    (numFilters, imgPixels, numImages)
 * target:   (numFilters, numOutputs, numImages)
 * 
 * B_X one of 8, 16, 32
 * imgsPerThread one of 1, 2, 4, 8, 16
 * 
 * B_XximgsPerThread MUST be divisible by 32.
 * Number of filters MUST be divisible by filtersPerThread.
 * 
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
 * 
 * Final write-out will not be fully coalesced unless B_X is 32. But there's a lot more
 * reading than writing here, and the reading is all coalesced, so it should be OK.
 * 
 * To be used when the stride is 1 and the pooling region is fairly large.
 */
template<class Agg, int B_X, int imgsPerThread, int filtersPerThread, bool checkCaseBounds>
__global__ void kLocalProbPool2(float* imgs, float* rnd, float* target, const int imgSize, const int numFilters,
              const int numImages, const int subsX, const int startX,
              const int outputsX, Agg agg) {
  // One staged pixel per (filter, image) for the whole block; imgs and rnd
  // are staged in lockstep so the selection below sees matching pairs.
  __shared__ float shImgs[filtersPerThread][B_X*imgsPerThread];
  __shared__ float shRnd[filtersPerThread][B_X*imgsPerThread];
  const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread);
  const int numFilterBlocks = numFilters/(filtersPerThread);
  // Each block covers a 4x4 patch of outputs.
  const int blockOutputX = 4*(blockIdx.x / numImgBlocks);
  const int blockOutputY = 4*(blockIdx.y / numFilterBlocks);
  const int blockImgIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
  const int blockFilterIdx = (blockIdx.y % numFilterBlocks) * filtersPerThread;
  
//  const int blockOutputIdx = blockOutputY * outputsX + blockOutputX;
  const int numOutputs = outputsX * outputsX;
  const int imgPixels = imgSize * imgSize;
  
  // Flat thread id, re-split into a 32-wide loading pattern for coalescing.
  const int tidx = threadIdx.y * B_X + threadIdx.x;
  const int loadY = tidx / 32, loadX = tidx % 32;
  
  // threadIdx.y (0..15) selects this thread's cell in the 4x4 output patch.
  const int myX = threadIdx.y % 4;
  const int myY = threadIdx.y / 4;
  
  const int myOutputIdxY = blockOutputY + myY;
  const int myOutputIdxX = blockOutputX + myX;
  const int myOutputIdx = myOutputIdxY * outputsX + myOutputIdxX;
  
  // Pooling window of the patch's top-left output (stride is 1 here).
  const int startImgPxX = startX + blockOutputX;
  const int startImgPxY = startX + blockOutputY;
  const int endImgPxX = startImgPxX + subsX;
  const int endImgPxY = startImgPxY + subsX;
  
  // This thread's own window, shifted by its cell offset within the patch.
  const int myStartImgPxY = startImgPxY + myY;
  const int myStartImgPxX = startImgPxX + myX;
  const int myEndImgPxY = endImgPxY + myY;
  const int myEndImgPxX = endImgPxX + myX;

  // Union of all 16 windows (+3 covers the patch extent), clamped to the image.
  const int loopStartY = MAX(startImgPxY, 0);
  const int loopStartX = MAX(startImgPxX, 0);
  const int loopEndY = MIN(imgSize, endImgPxY + 3);
  const int loopEndX = MIN(imgSize, endImgPxX + 3);

  const int imgIdx = blockImgIdx + threadIdx.x;
  
  // Load pointers use the (loadY, loadX) pattern; target uses (threadIdx.y, threadIdx.x).
  imgs += (blockFilterIdx + loadY) * imgPixels * numImages + blockImgIdx + loadX;
  rnd += (blockFilterIdx + loadY) * imgPixels * numImages + blockImgIdx + loadX;
  target += (blockFilterIdx * numOutputs + myOutputIdx) * numImages + imgIdx;
  
  // Running winner per (filter, image) and the random value that backed it.
  float prod[filtersPerThread][imgsPerThread];
  float rnd_used[filtersPerThread][imgsPerThread];
  #pragma unroll
  for (int f = 0; f < filtersPerThread; f++) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      prod[f][i] = agg.getBaseValue(); 
      rnd_used[f][i] = agg.getBaseValue(); 
    }
  }

  for (int y = loopStartY; y < loopEndY; y++) {
    const bool isInY = y >= myStartImgPxY && y < myEndImgPxY ;
    for (int x = loopStartX; x < loopEndX; x++) {
      // Load a pixel
      const int px = y * imgSize + x;
      #pragma unroll
      for (int ly = 0; ly < filtersPerThread; ly += B_X/2) {
        if (filtersPerThread % (B_X/2) == 0 || ly + loadY < filtersPerThread) {
          #pragma unroll
          for (int lx = 0; lx < B_X*imgsPerThread; lx += 32) {
            if (!checkCaseBounds || lx + loadX + blockImgIdx < numImages) {
              shImgs[ly + loadY][lx + loadX] = imgs[(ly * imgPixels + px) * numImages + lx];
              shRnd[ly + loadY][lx + loadX] = rnd[(ly * imgPixels + px) * numImages + lx];
            }
          }
        }
      }
      // Barrier: all staged writes must land before any thread reads them.
      __syncthreads();

      // Is this pixel in my region?
      if (isInY && x >= myStartImgPxX && x < myEndImgPxX) {
        #pragma unroll
        for (int i = 0; i < imgsPerThread; i++) {
          if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
            #pragma unroll
            for (int f = 0; f < filtersPerThread; f++) {
              const int loc = threadIdx.x + i * B_X;
              // agg returns a float 0/1 selector (see ProbMaxPooler),
              // truncated to int: 0 keeps the current winner, otherwise
              // the staged pixel (and its rnd) replaces it.
              const int res = agg(prod[f][i], shImgs[f][loc], rnd_used[f][i], shRnd[f][loc]);
              prod[f][i] = res == 0 ? prod[f][i] : shImgs[f][loc];
              rnd_used[f][i] = res == 0 ? rnd_used[f][i] : shRnd[f][loc];
            }
          }
        }
      }
      // Barrier before the next iteration overwrites the staging buffers.
      __syncthreads();

    }
  }
  // Edge blocks may own cells past the output grid; only in-range cells write.
  if (myOutputIdxY < outputsX && myOutputIdxX < outputsX) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
        #pragma unroll
        for (int f = 0; f < filtersPerThread; f++) {
          target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]); 
        }
      }
    }
  }
}


/*
 * Block size 16xB_X
 * blockIdx.x determines 4x4 pixel.x region, image idx in batches of B_X*imgsPerThread
 * blockIdx.y determines 4x4 pixel.y region, filter idx in batches of filtersPerThread
 * 
 * So each block does a 4x4 region for some number of images/filters.
 * 
 * threadIdx.x determines img idx
 * threadIdx.y determines pixel idx
 * 
 * imgs:    (numFilters, imgPixels, numImages)
 * target:   (numFilters, numOutputs, numImages)
 * 
 * B_X one of 8, 16, 32
 * imgsPerThread one of 1, 2, 4, 8, 16
 * 
 * B_XximgsPerThread MUST be divisible by 32.
 * Number of filters MUST be divisible by filtersPerThread.
 * 
 * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false
 * 
 * Final write-out will not be fully coalesced unless B_X is 32. But there's a lot more
 * reading than writing here, and the reading is all coalesced, so it should be OK.
 * 
 * To be used when the stride is 1 and the pooling region is fairly large.
 */
template<class Agg, int B_X, int imgsPerThread, int filtersPerThread, bool checkCaseBounds>
__global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const int numFilters,
              const int numImages, const int subsX, const int startX,
              const int outputsX, Agg agg) {
  // One staged pixel per (filter, image) for the whole block.
  __shared__ float shImgs[filtersPerThread][B_X*imgsPerThread];
  const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread);
  const int numFilterBlocks = numFilters/(filtersPerThread);
  // Each block covers a 4x4 patch of outputs.
  const int blockOutputX = 4*(blockIdx.x / numImgBlocks);
  const int blockOutputY = 4*(blockIdx.y / numFilterBlocks);
  const int blockImgIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
  const int blockFilterIdx = (blockIdx.y % numFilterBlocks) * filtersPerThread;
  
//  const int blockOutputIdx = blockOutputY * outputsX + blockOutputX;
  const int numOutputs = outputsX * outputsX;
  const int imgPixels = imgSize * imgSize;
  
  // Flat thread id, re-split into a 32-wide loading pattern for coalescing.
  const int tidx = threadIdx.y * B_X + threadIdx.x;
  const int loadY = tidx / 32, loadX = tidx % 32;
  
  // threadIdx.y (0..15) selects this thread's cell in the 4x4 output patch.
  const int myX = threadIdx.y % 4;
  const int myY = threadIdx.y / 4;
  
  const int myOutputIdxY = blockOutputY + myY;
  const int myOutputIdxX = blockOutputX + myX;
  const int myOutputIdx = myOutputIdxY * outputsX + myOutputIdxX;
  
  // Pooling window of the patch's top-left output (stride is 1 here).
  const int startImgPxX = startX + blockOutputX;
  const int startImgPxY = startX + blockOutputY;
  const int endImgPxX = startImgPxX + subsX;
  const int endImgPxY = startImgPxY + subsX;
  
  // This thread's own window, shifted by its cell offset within the patch.
  const int myStartImgPxY = startImgPxY + myY;
  const int myStartImgPxX = startImgPxX + myX;
  const int myEndImgPxY = endImgPxY + myY;
  const int myEndImgPxX = endImgPxX + myX;

  // Union of all 16 windows (+3 covers the patch extent), clamped to the image.
  const int loopStartY = MAX(startImgPxY, 0);
  const int loopStartX = MAX(startImgPxX, 0);
  const int loopEndY = MIN(imgSize, endImgPxY + 3);
  const int loopEndX = MIN(imgSize, endImgPxX + 3);

  const int imgIdx = blockImgIdx + threadIdx.x;
  
  // Load pointer uses the (loadY, loadX) pattern; target uses (threadIdx.y, threadIdx.x).
  imgs += (blockFilterIdx + loadY) * imgPixels * numImages + blockImgIdx + loadX;
  target += (blockFilterIdx * numOutputs + myOutputIdx) * numImages + imgIdx;
  
  // Per-thread accumulators seeded with the aggregator's identity value.
  float prod[filtersPerThread][imgsPerThread];
  #pragma unroll
  for (int f = 0; f < filtersPerThread; f++) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      prod[f][i] = agg.getBaseValue(); 
    }
  }

  for (int y = loopStartY; y < loopEndY; y++) {
    const bool isInY = y >= myStartImgPxY && y < myEndImgPxY ;
    for (int x = loopStartX; x < loopEndX; x++) {
      // Load a pixel
      const int px = y * imgSize + x;
      #pragma unroll
      for (int ly = 0; ly < filtersPerThread; ly += B_X/2) {
        if (filtersPerThread % (B_X/2) == 0 || ly + loadY < filtersPerThread) {
          #pragma unroll
          for (int lx = 0; lx < B_X*imgsPerThread; lx += 32) {
            if (!checkCaseBounds || lx + loadX + blockImgIdx < numImages) {
              shImgs[ly + loadY][lx + loadX] = imgs[(ly * imgPixels + px) * numImages + lx];
            }
          }
        }
      }
      // Barrier: all staged writes must land before any thread reads them.
      __syncthreads();

      // Is this pixel in my region?
      if (isInY && x >= myStartImgPxX && x < myEndImgPxX) {
        #pragma unroll
        for (int i = 0; i < imgsPerThread; i++) {
          if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
            #pragma unroll
            for (int f = 0; f < filtersPerThread; f++) {
              prod[f][i] = agg(prod[f][i], shImgs[f][threadIdx.x + i * B_X]);
            }
          }
        }
      }
      // Barrier before the next iteration overwrites the staging buffer.
      __syncthreads();

    }
  }
  // Edge blocks may own cells past the output grid; only in-range cells write.
  if (myOutputIdxY < outputsX && myOutputIdxX < outputsX) {
    #pragma unroll
    for (int i = 0; i < imgsPerThread; i++) {
      if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
        #pragma unroll
        for (int f = 0; f < filtersPerThread; f++) {
          target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]); 
        }
      }
    }
  }
}

/*
 * imgs:    (numFilters, imgPixels, numImages)
 * rnd:    (numFilters, imgPixels, numImages)
 * target:   (numFilters, outputs, numImages)
 */
/*
 * Host-side dispatcher for probabilistic local pooling.
 *
 * images:  (numFilters, imgPixels, numImages) in a cudamat (size[0] = numImages)
 * rnd:     (numFilters, imgPixels, numImages) -- random values, same shape as images
 * target:  (numFilters, outputsX*outputsX, numImages)
 *
 * Picks between two kernels: kLocalProbPool2 (shared-memory staging; used for
 * stride-1, large-window pooling) and kLocalProbPool (direct reads; the general
 * case). Template parameters (B_X, imgsPerThread, filtersPerThread,
 * checkCaseBounds) must be compile-time constants, hence the explicit
 * instantiation ladder below.
 *
 * Fix vs. previous revision: the getLastCudaError message said
 * "convLocalPool", mislabeling failures of this function.
 */
template<class Pooler>
void convLocalProbPoolCu(cudamat* images, cudamat* rnd, cudamat* target, int numFilters, int subsX, int startX, int strideX, int outputsX, Pooler pooler) {
  int numImages = images->size[0];
  int imgPixels = images->size[1] / numFilters;
  assert(images->size[1] == numFilters * imgPixels);
  // Images must be square.
  int imgSize = int(sqrt(imgPixels));
  assert(imgSize * imgSize == imgPixels);
  // rnd must match images element-for-element.
  assert(rnd->size[0] == images->size[0]);
  assert(rnd->size[1] == images->size[1]);

  if (strideX == 1 && subsX >= 6) {
    // Shared-memory path: each block computes a 4x4 patch of outputs.
    int imgsPerThread = 8;
    // Largest divisor of numFilters in {4, 3, 2, 1}.
    int filtersPerThread = numFilters % 4 == 0 ? 4 : numFilters % 3 == 0 ? 3 : numFilters % 2 == 0 ? 2 : 1;
    int bx = 8;
    // Bounds checks are only needed when images don't fill the last block.
    bool checkCaseBounds = numImages % (bx*imgsPerThread) != 0;
    assert((imgsPerThread * bx) % 32 == 0);
    assert(numFilters % filtersPerThread == 0);
    dim3 threads(bx, 16);
    dim3 blocks(DIVUP(outputsX, 4) * DIVUP(numImages, bx*imgsPerThread), DIVUP(outputsX, 4) * numFilters / filtersPerThread);
    if (filtersPerThread == 1) {
       if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 1, true>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 1, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 1, false>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 1, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 2) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 2, true>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 2, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 2, false>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 2, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 3) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 3, true>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 3, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 3, false>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 3, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 4) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 4, true>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 4, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool2<Pooler, 8, 8, 4, false>, cudaFuncCachePreferShared);
        kLocalProbPool2<Pooler, 8, 8, 4, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    }
  } else {
    // General path: one block per output pixel per image/filter batch,
    // B_Y=4, B_X=32, imgsPerThread=4.
    bool checkCaseBounds = numImages % 128 != 0;
    int filtersPerThread = numFilters % 8 == 0 ? 2 : 1;
    dim3 threads(32, 4);
    dim3 blocks(DIVUP(numImages,32*4) * outputsX, DIVUP(numFilters, 4 * filtersPerThread) * outputsX);
    if (filtersPerThread == 1) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool<Pooler, 4, 32, 4, 1, true>, cudaFuncCachePreferL1);
        kLocalProbPool<Pooler, 4, 32, 4, 1, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool<Pooler, 4, 32, 4, 1, false>, cudaFuncCachePreferL1);
        kLocalProbPool<Pooler, 4, 32, 4, 1, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      }
    } else {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalProbPool<Pooler, 4, 32, 4, 2, true>, cudaFuncCachePreferL1);
        kLocalProbPool<Pooler, 4, 32, 4, 2, true><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalProbPool<Pooler, 4, 32, 4, 2, false>, cudaFuncCachePreferL1);
        kLocalProbPool<Pooler, 4, 32, 4, 2, false><<<blocks, threads>>>(images->data_device, rnd->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      }
    }

  }

  // Corrected label: this is the prob-pool path, not convLocalPool.
  getLastCudaError("convLocalProbPool: kernel execution failed");
}

/*
 * imgs:    (numFilters, imgPixels, numImages)
 * target:   (numFilters, outputs, numImages)
 */
/*
 * Host-side dispatcher for local pooling (avg/max via the Pooler template).
 * Chooses kLocalPool2 (shared-memory staging) for stride-1, large-window
 * pooling, otherwise the general kLocalPool kernel; the if/else ladder
 * exists because the kernels' tuning knobs are template (compile-time)
 * parameters.
 */
template<class Pooler>
void convLocalPoolCu(cudamat* images, cudamat* target, int numFilters, int subsX, int startX, int strideX, int outputsX, Pooler pooler) {
  int numImages = images->size[0];
  int imgPixels = images->size[1] / numFilters;
  assert(images->size[1] == numFilters * imgPixels);
  // Images must be square.
  int imgSize = int(sqrt(imgPixels));
  assert(imgSize * imgSize == imgPixels);
  
  /*
  assert(!images.isTrans());
  assert(!target.isTrans());
  assert(images.isContiguous());
  */
//  assert(numFilters % 4 == 0);
//  assert(numImages % 128 == 0);
  
  // int outputs = outputsX * outputsX;
  //target.resize(numFilters*outputs, numImages);

  if (strideX == 1 && subsX >= 6) {
    // Shared-memory path: each block computes a 4x4 patch of outputs.
    int imgsPerThread = 8;
    // Largest divisor of numFilters in {4, 3, 2, 1}.
    int filtersPerThread = numFilters % 4 == 0 ? 4 : numFilters % 3 == 0 ? 3 : numFilters % 2 == 0 ? 2 : 1;
    int bx = 8;
    // Bounds checks only needed when images don't fill the last block.
    bool checkCaseBounds = numImages % (bx*imgsPerThread) != 0;
    assert((imgsPerThread * bx) % 32 == 0);
    assert(numFilters % filtersPerThread == 0);
    dim3 threads(bx, 16);
    dim3 blocks(DIVUP(outputsX, 4) * DIVUP(numImages, bx*imgsPerThread), DIVUP(outputsX, 4) * numFilters / filtersPerThread);
    if (filtersPerThread == 1) {
       if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 1, true>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 1, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 1, false>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 1, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 2) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 2, true>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 2, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 2, false>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 2, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 3) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 3, true>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 3, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 3, false>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 3, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    } else if (filtersPerThread == 4) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 4, true>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 4, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool2<Pooler, 8, 8, 4, false>, cudaFuncCachePreferShared);
        kLocalPool2<Pooler, 8, 8, 4, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, outputsX, pooler);
      }
    }
  } else {
    // General path: B_Y=4, B_X=32, imgsPerThread=4.
    bool checkCaseBounds = numImages % 128 != 0;
    int filtersPerThread = numFilters % 8 == 0 ? 2 : 1;
    dim3 threads(32, 4);
    dim3 blocks(DIVUP(numImages,32*4) * outputsX, DIVUP(numFilters, 4 * filtersPerThread) * outputsX);
    if (filtersPerThread == 1) {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool<Pooler, 4, 32, 4, 1, true>, cudaFuncCachePreferL1);
        kLocalPool<Pooler, 4, 32, 4, 1, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool<Pooler, 4, 32, 4, 1, false>, cudaFuncCachePreferL1);
        kLocalPool<Pooler, 4, 32, 4, 1, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      }
    } else {
      if (checkCaseBounds) {
        cudaFuncSetCacheConfig(kLocalPool<Pooler, 4, 32, 4, 2, true>, cudaFuncCachePreferL1);
        kLocalPool<Pooler, 4, 32, 4, 2, true><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      } else {
        cudaFuncSetCacheConfig(kLocalPool<Pooler, 4, 32, 4, 2, false>, cudaFuncCachePreferL1);
        kLocalPool<Pooler, 4, 32, 4, 2, false><<<blocks, threads>>>(images->data_device, target->data_device,
                                 imgSize, numFilters, numImages, subsX, startX, strideX, outputsX, pooler);
      }
    }

  }

  getLastCudaError("convLocalPool: kernel execution failed");
}
